How to Implement Your Own Estimators

Every estimator, in every package, follows a pre-defined protocol based on its type. Any algorithm implementation that follows the protocols in s3l.base can be evaluated by the experiment classes in the same way as the built-in algorithms.

Your estimator should inherit from the base estimator class in s3l.base that matches the type of algorithm you are implementing. Five options are currently provided:

  1. TransductiveEstimatorwithGraph,
  2. TransductiveEstimatorWOGraph,
  3. InductiveEstimatorWOGraph,
  4. InductiveEstimatorwithGraph,
  5. SupervisedEstimator.

As the names indicate, the experiments support supervised learning algorithms as well as semi-supervised learning algorithms in both inductive and transductive settings, with or without a graph.

For each estimator class, you must implement the following methods: set_params, fit and predict.

set_params is the method that configures the parameters of the estimator object from a dict of parameter values. The experiments call it while searching for the best hyper-parameters. Since the same object is used repeatedly with different hyper-parameters, you should make sure that set_params resets the object as if it had never been trained. A common implementation is as follows.

def set_params(self, param):
    """Parameter setting function.

    Parameters
    ----------
    param : dict
        Store parameter names and corresponding values {'name': value}.
    """
    if isinstance(param, dict):
        self.__dict__.update(param)

    # Code to reset any properties that may influence the
    # prediction goes here.
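To make the reset step concrete, here is a minimal sketch under stated assumptions. MidpointEstimator, its margin parameter and its threshold_ attribute are hypothetical names invented for illustration, and the class deliberately does not inherit from s3l.base so that the snippet runs on its own; a real estimator would inherit from one of the base classes listed above.

```python
class MidpointEstimator:
    """Hypothetical one-feature classifier used to illustrate set_params."""

    def __init__(self, margin=0.0):
        self.margin = margin    # hyper-parameter searched by the experiment
        self.threshold_ = None  # learned state, filled in by fit

    def set_params(self, param):
        if isinstance(param, dict):
            self.__dict__.update(param)
        # Reset learned state so the object behaves as if never trained.
        self.threshold_ = None

    def fit(self, X, y, l_ind, **kwargs):
        # Use only the labeled rows listed in l_ind.
        zeros = [X[i][0] for i in l_ind if y[i] == 0]
        ones = [X[i][0] for i in l_ind if y[i] == 1]
        midpoint = (sum(zeros) / len(zeros) + sum(ones) / len(ones)) / 2
        self.threshold_ = midpoint + self.margin

    def predict(self, X, **kwargs):
        return [1 if row[0] > self.threshold_ else 0 for row in X]


X = [[0.0], [1.0], [0.2], [0.8]]
y = [0, 1, None, None]  # None marks unlabeled instances in this toy data
est = MidpointEstimator()
est.fit(X, y, l_ind=[0, 1])
first = est.predict([[0.2], [0.8]])

# Reuse the same object with a new hyper-parameter value.
est.set_params({'margin': 0.5})
was_reset = est.threshold_ is None
est.fit(X, y, l_ind=[0, 1])
second = est.predict([[0.2], [0.8]])
```

Without the reset in set_params, a stale threshold_ from the previous fit could survive into the next round of the hyper-parameter search.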

fit is the method that trains the model on the given data; predict is the method that makes predictions. The main difference between the base classes is the parameters of fit and predict. For a transductive estimator, the predict method takes the indexes of the instances to predict (the estimator can see the testing data during training). For an inductive estimator, the predict method takes the features of the instances to predict. The fit method always takes X, y and l_ind, and optional arguments are supported. For graph-based algorithms, W must also be provided to the fit method.
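The difference between the two predict signatures can be sketched with a toy nearest-labeled-neighbor rule. This is an illustration only, not code from s3l: the class names and the u_ind parameter name are assumptions, the classes do not inherit from s3l.base so that the snippet is self-contained, and set_params is omitted for brevity.

```python
class NearestLabeledTransductive:
    """Hypothetical transductive estimator (no graph)."""

    def fit(self, X, y, l_ind, **kwargs):
        # X holds labeled and unlabeled instances; l_ind marks the labeled ones.
        self.X_ = X
        self.y_ = y
        self.l_ind_ = list(l_ind)

    def predict(self, u_ind, **kwargs):
        # Transductive: predict by index into the X seen during fit.
        return [self._label_of_nearest(self.X_[i]) for i in u_ind]

    def _label_of_nearest(self, row):
        def sq_dist(j):
            return sum((a - b) ** 2 for a, b in zip(row, self.X_[j]))
        # Copy the label of the closest labeled instance.
        return self.y_[min(self.l_ind_, key=sq_dist)]


class NearestLabeledInductive(NearestLabeledTransductive):
    """Hypothetical inductive counterpart: same fit, different predict."""

    def predict(self, X, **kwargs):
        # Inductive: predict from feature rows, which may be unseen.
        return [self._label_of_nearest(row) for row in X]


X = [[0.0], [0.1], [1.0], [0.9]]
y = [0, None, 1, None]  # None marks unlabeled instances in this toy data

trans = NearestLabeledTransductive()
trans.fit(X, y, l_ind=[0, 2])
trans_pred = trans.predict([1, 3])          # indexes of instances in X

induc = NearestLabeledInductive()
induc.fit(X, y, l_ind=[0, 2])
induc_pred = induc.predict([[0.2], [0.8]])  # feature rows not in X
```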

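For a supervised learning algorithm, you can inherit from the SupervisedEstimator class. You must override the __init__ method and initialize the member model as a supervised learning model object. As the class definition below shows, model must provide the methods that SupervisedEstimator delegates to: fit, predict, set_params, predict_proba and predict_log_proba.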

class SupervisedEstimator(BaseEstimator):
    """ Supervised estimator of single-label task.
    """

    @abstractmethod
    def __init__(self):
        super(SupervisedEstimator, self).__init__()
        self.model = None

    def fit(self, X, y, l_ind=None, **kwargs):
        """
        Takes X, y, label_index.
        """
        if l_ind is not None:
            X = X[l_ind, :]
            if y.ndim == 2:
                y = y[l_ind, :].reshape(-1)
            else:
                y = y[l_ind]
        self.model.fit(X, y)

    def predict(self, X, **kwargs):
        """
        Takes X
        """
        return self.model.predict(X)

    def set_params(self, param):
        self.model.set_params(**param)

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    def predict_log_proba(self, X):
        return self.model.predict_log_proba(X)

s3l.wrapper.sklearn_wrapper shows how to wrap any supervised learning algorithm you like.

Attention

Sometimes your estimator class may contain a member that is a C-language object. The estimator object can then be un-serializable when the C object holds pointers, because the Python interpreter has no way of knowing the details of the memory that the pointers refer to.

The experiment classes run the experiments in multi-process mode when n_jobs is set larger than 1, which requires the estimator object to be serializable. One option is to override the __getstate__ and __setstate__ methods to control how the estimator object is dumped and loaded by pickle. The simplest way is to drop the un-picklable member in __getstate__ and re-initialize it in __setstate__. Here is an example taken from s3l.classification.TSVM, where self.model is a C object:

def __getstate__(self):
    """
    The model is a ctypes object containing pointers, so it cannot be
    pickled. We drop the model when we pickle TSVM.
    """
    state = self.__dict__.copy()
    del state['model']  # manually delete
    return state

def __setstate__(self, state):
    """
    The model was dropped when TSVM was pickled, so we restore the
    remaining state and reset the model to None.
    """
    self.__dict__.update(state)
    self.model = None  # manually update
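The same pattern can be tried end to end with a pickle round trip. The sketch below is an assumption-laden illustration rather than the TSVM code: PointerBackedEstimator, its C parameter and the _build_model helper are hypothetical, and a ctypes pointer stands in for the handle a real C library would return. Unlike the TSVM example, this variant rebuilds the member in __setstate__ instead of leaving it as None.

```python
import ctypes
import pickle


class PointerBackedEstimator:
    """Hypothetical estimator whose model member holds a C pointer."""

    def __init__(self, C=1.0):
        self.C = C
        self.model = self._build_model()

    def _build_model(self):
        # Stand-in for a handle created by a C library; ctypes refuses
        # to pickle objects that contain pointers.
        return ctypes.pointer(ctypes.c_double(self.C))

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['model']  # drop the un-picklable member
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-create the member from the picklable state.
        self.model = self._build_model()


est = PointerBackedEstimator(C=0.5)
restored = pickle.loads(pickle.dumps(est))  # what a worker process would receive
restored_value = restored.model.contents.value
```

Rebuilding in __setstate__ only works when the member can be reconstructed from the remaining state. If it holds trained weights, those would need to be copied into a picklable form in __getstate__ as well.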